require(magrittr)
require(knitr)
require(dplyr)
require(tidyr)
require(tibble)
require(ggplot2)
require(Seurat)
require(cowplot)
require(kableExtra)
require(readr)
require(readxl)
require(slingshot)
require(SingleCellExperiment)
require(mgcv)
require(slingshot)
require(tradeSeq)
require(patchwork)
# Pull in helper functions from a sibling script (path is relative to the
# current working directory). Presumably Seurat utilities, going by the file
# name -- TODO confirm which of them this script actually relies on.
source("seurat_helper_functions.R")
# Directory prefix (note the trailing slash) that is paste0()-ed in front of
# every intermediate .RData file this script loads or saves.
parent.directory <- "intermediate/"

# Row-scale a feature table, cluster its rows hierarchically and draw a
# dendrogram | cluster strip | heatmap composite.
#
# my_data_frame  table with one feature per row (row names are the feature
#                names); each row is z-scored and clipped to +/- max_zed.
# clusters       number of groups requested from cutree().
# max_zed        absolute cap applied to the z-scores.
# plot_rownames  accepted for compatibility; not used by the body.
# color_brewer   RColorBrewer palette name for the cluster strip.
# label_features feature names that receive a y-axis label on the heatmap.
# y_text_size    heatmap y-axis text size; legend text uses 0.75 x this value.
# get_clusters   if TRUE, return the cluster table and skip all plotting.
# dist_method    `method` argument handed to dist().
#
# Returns a data frame with columns `cluster` and `gene` when get_clusters is
# TRUE, otherwise the combined plot.
cluster_heatmap <- function(my_data_frame, clusters = 3, max_zed = 3, plot_rownames = FALSE, color_brewer = "Set1", label_features = NULL, y_text_size = 10, get_clusters = FALSE, dist_method = "euclidean"){
  # Row-wise z-scores, clipped to the requested range.
  z_scores <- as.data.frame(t(scale(t(my_data_frame))))
  z_scores[z_scores > max_zed] <- max_zed
  z_scores[z_scores < -max_zed] <- -max_zed

  # Cluster the rows; the dendrogram leaf order drives every y axis below.
  row_tree <- hclust(dist(z_scores, method = dist_method))
  row_order <- row_tree$labels[row_tree$order]

  cluster_table <- data.frame(cluster = cutree(row_tree, k = clusters))
  cluster_table$cluster <- as.factor(cluster_table$cluster)
  cluster_table$gene <- factor(rownames(cluster_table), levels = row_order)
  if (get_clusters) {
    return(cluster_table)
  }

  # Long format for geom_tile(); columns keep their original left-to-right order.
  long_scores <- pivot_longer(
    rownames_to_column(z_scores, var = "gene"),
    -gene,
    values_to = "Z_sc",
    names_to = "pseudotime"
  )
  long_scores$gene <- factor(long_scores$gene, levels = row_order)
  long_scores$pseudotime <- factor(long_scores$pseudotime, levels = colnames(z_scores))

  # Only the requested features get a visible label; all others are blank.
  axis_labels <- ifelse(row_order %in% label_features, row_order, "")

  heat_panel <- ggplot(long_scores, aes(x = pseudotime, y = gene, fill = Z_sc)) +
    geom_tile() +
    scale_fill_viridis_c() +
    ylab(NULL) +
    scale_y_discrete(
      position = "right",
      expand = c(0.01, 0.01),
      labels = axis_labels,
      breaks = axis_labels
    ) +
    theme(
      axis.text.x = element_blank(),
      axis.ticks.x = element_blank(),
      plot.margin = margin(l = 0),
      panel.background = element_rect(color = "white"),
      axis.text.y = element_text(size = y_text_size),
      legend.position = "bottom"
    )

  tree_panel <- ggdendro::ggdendrogram(data = row_tree) +
    scale_y_reverse(expand = c(0.01, 0.01)) +
    coord_flip() +
    theme(
      axis.text.y = element_blank(),
      axis.text.x = element_blank(),
      plot.margin = margin(r = 0)
    ) +
    scale_x_discrete(expand = c(0.01, 0.01))

  strip_panel <- ggplot(cluster_table, aes(x = 1, y = gene, fill = cluster)) +
    geom_tile() +
    scale_fill_manual(values = RColorBrewer::brewer.pal(clusters, color_brewer)) +
    scale_y_discrete(position = "right", expand = c(0.01, 0.01), labels = NULL) +
    theme(
      axis.text.x = element_blank(),
      axis.ticks = element_blank(),
      plot.margin = margin(l = 0, r = 0),
      panel.background = element_rect(color = "white"),
      legend.position = "bottom"
    ) +
    ylab(NULL) +
    xlab("")

  # Side-by-side layout with shared legends; `&` applies the legend theme to
  # every panel.
  composite <- (tree_panel | strip_panel | heat_panel) +
    patchwork::plot_layout(widths = c(0.9, 0.1, 2), guides = "collect")
  composite & theme(
    legend.position = "bottom",
    legend.text = element_text(size = y_text_size * 0.75),
    legend.title = element_text(size = y_text_size * 0.75),
    legend.key.size = unit(0.2, "cm")
  )
}

# Dot plot of enrichment terms ordered by gene ratio.
#
# enrichment_results  object whose `@result` slot is a data frame with (at
#                     least) the columns ID, GeneRatio ("count/total" strings)
#                     and p.adjust.
# p_adjust            only terms with p.adjust below this value are drawn.
# fill_color          colour of the points.
#
# Returns a ggplot with x = Count / Total, y = term ID (ordered by ascending
# ratio) and point size = Count.
gea_dotplot <- function(enrichment_results, p_adjust = 0.01, fill_color = "black"){
  term_table <- enrichment_results@result %>%
    # Shorten the HALLMARK prefix, then put each word of the ID on its own line.
    mutate(ID = gsub("_", "\n", gsub("HALLMARK_", "HM_", ID))) %>%
    separate(col = GeneRatio, sep = "/", into = c("Count", "Total")) %>%
    mutate(
      Count = as.numeric(Count),
      Total = as.numeric(Total),
      numeric_ratio = Count / Total
    ) %>%
    arrange(numeric_ratio) %>%
    filter(p.adjust < p_adjust) %>%
    # Freeze the current (ratio-sorted) row order as the y-axis order.
    mutate(ID = factor(ID, levels = ID))

  ggplot(term_table, aes(x = numeric_ratio, y = ID, size = Count)) +
    geom_point(color = fill_color) +
    theme_cowplot() +
    ylab(NULL) +
    xlab("Gene Ratio")
}


# Scatter one feature against pseudotime (one point per cell, coloured by
# cluster) and overlay the smoothed curve for that feature.
#
# seurat_object    Seurat object holding the feature and the metadata columns.
# smoothed_values  data frame with columns `geneset`, `impute_pseudotime` and
#                  `impute_auc`; rows with geneset == chosen_feature are drawn
#                  as the curve.
# chosen_feature   name of the feature to plot.
# pseudotime_var   metadata column used for the x axis; cells with NA in this
#                  column are dropped.
# cluster_choice   metadata column used for point colour.
# assay            assay in which `chosen_feature` should be looked up first.
#
# Returns a ggplot object.
plot_gam_auc <- function(seurat_object, smoothed_values, chosen_feature, pseudotime_var = "pt_target_curve_trimmed", cluster_choice = "cluster_names", assay = "regulons"){
  # `assay` used to be ignored, so FetchData() searched the object's default
  # assay first and warned for every feature that only exists elsewhere.
  # Switching the default assay here affects only this function's local copy of
  # the object. If the requested assay does not exist, the object is left as
  # is and FetchData() falls back to its usual search.
  if (!is.null(assay) && assay %in% names(seurat_object@assays)) {
    DefaultAssay(seurat_object) <- assay
  }

  smoothed_values_subset <- dplyr::filter(smoothed_values, geneset == chosen_feature)

  feature_val <- FetchData(seurat_object, vars = chosen_feature)
  temp_df <- FetchData(seurat_object, vars = c(pseudotime_var, cluster_choice))
  temp_df <- cbind(feature_val, temp_df)
  # Fixed column names so the aes() mapping below is independent of the inputs.
  colnames(temp_df) <- c("feature_val", "pseudotime", "cluster")
  temp_df <- dplyr::filter(temp_df, !is.na(pseudotime))

  ggplot(temp_df, aes(x = pseudotime, y = feature_val, color = cluster))+
    geom_point()+
    geom_path(data = smoothed_values_subset, aes(x = impute_pseudotime, y = impute_auc, group = geneset), inherit.aes = FALSE, size = 2)+
    cowplot::theme_cowplot()
}

A. Non-GC and Non-PC extraction

A.1 Getting cell identifiers from the large Seurat object

# Load the two saved workspaces (presumably providing B.gcpc and B.combined,
# which are used immediately below -- confirm against the files).
load(paste0(parent.directory, "sct_seurat_gcpc_named_clusters.RData"))
load(paste0(parent.directory, "sct_reguloned.RData"))

# Restrict the large object to the cells present in B.gcpc, copy the four
# score assays across, then drop the large objects again.
B.subset <- B.combined[, colnames(B.gcpc)]
for (assay_name in c("regulons", "genesetAUC", "btmAUC", "hmAUC")) {
  B.gcpc[[assay_name]] <- GetAssay(B.subset, assay = assay_name)
}
rm(B.combined, B.subset, assay_name)

# Get lineages and curves from slingshot.
# Three identities are excluded before trajectory inference.
target_idents <- setdiff(
  levels(Idents(B.gcpc)),
  c("GC IgA", "ASC IgM", "ASC IgG")
)
target_subset <- subset(B.gcpc, idents = target_idents)

# Keep only cells whose first pca_UMAP coordinate is above -5.
cells_to_keep <- Embeddings(target_subset, "pca_UMAP") %>%
  as.data.frame() %>%
  dplyr::filter(pcaumap_1 > -5) %>%
  rownames()
target_subset <- subset(target_subset, cells = cells_to_keep)

gcpc.lineages <- slingshot::getLineages(
  data = Embeddings(target_subset, "pca_UMAP"),
  clusterLabels = target_subset$cluster_names,
  start.clus = "GC 1",
  end.clus = c("GC LMO2", "LZ", "DZ 4")
)
gcpc.curves <- slingshot::getCurves(gcpc.lineages, extend = "n")

# Add pseudotime variables to both Seurat objects, one "pt_<curve>" metadata
# column per slingshot curve.
curve_pseudotimes <- slingshot::slingPseudotime(gcpc.curves)
target_subset <- AddMetaData(
  target_subset,
  metadata = curve_pseudotimes,
  col.name = paste0("pt_", colnames(curve_pseudotimes))
)
B.gcpc <- AddMetaData(
  B.gcpc,
  metadata = curve_pseudotimes,
  col.name = paste0("pt_", colnames(curve_pseudotimes))
)

save(
  B.gcpc,
  target_subset,
  gcpc.curves,
  gcpc.lineages,
  file = paste0(parent.directory, "sct_slingshot_results.RData")
)

A.2 Visualizing tree and trajectories

# Reload the slingshot results saved by the previous section.
load(file = paste0(parent.directory, "sct_slingshot_results.RData"))

# Base-graphics overview: all cells in the pca_UMAP embedding with the
# lineage tree drawn on top.
plot(Embeddings(B.gcpc, "pca_UMAP"), col = RColorBrewer::brewer.pal(9,"Set1")[2], asp = 1, pch = 16)
lines(gcpc.lineages, lwd = 3, col = 'black')

base_dimplot <- DimPlot(B.gcpc, label = TRUE, label.box = TRUE, repel = TRUE, label.size = 3)+NoLegend()

# One geom_path layer per slingshot curve. seq_len() (instead of 1:n) yields an
# empty list rather than an out-of-bounds error when there are no curves.
num_curves <- length(gcpc.curves@curves)
slingshot_path_mappings <- lapply(seq_len(num_curves), function(i){
  curve_i <- gcpc.curves@curves[[i]]
  # Order the curve points along the curve; drop = FALSE keeps the result a
  # two-column matrix even when only a single point is selected.
  curve_data <- as.data.frame(curve_i$s[curve_i$ord, , drop = FALSE])
  geom_path(data = curve_data, aes(x = pcaumap_1, y = pcaumap_2), inherit.aes = FALSE, color = "black", size = 1)
})

# Splice the curve layers between the first and second layer of the DimPlot
# (assumed to be the points and the cluster labels, so that labels stay on
# top -- this relies on DimPlot producing exactly those two layers).
base_dimplot_paths <- base_dimplot
base_dimplot_paths$layers <- c(base_dimplot_paths$layers[[1]], slingshot_path_mappings, base_dimplot_paths$layers[[2]] )
base_dimplot_paths

# Curve of interest: its metadata column name and its position in
# slingshot_path_mappings (both refer to curve 4 and must be kept in sync).
target_curve_index <- 4
target_curve <- "pt_curve4"

# Pseudotime along the target curve, kept only for cells of the two clusters
# of interest; every other cell is set to NA.
B.gcpc$pt_target_curve <- ifelse(
  B.gcpc$cluster_names %in% c("GC LMO2", "GC 1"),
  B.gcpc@meta.data[, target_curve],
  NA
)

pt_dimplot <- FeaturePlot(B.gcpc, features = "pt_target_curve") +
  scale_color_viridis_c("pseudotime") +
  slingshot_path_mappings[[target_curve_index]] +
  ggtitle("")
## Scale for 'colour' is already present. Adding another scale for 'colour',
## which will replace the existing scale.
pt_dimplot

# Trim the start of the trajectory: pseudotime values at or below the cutoff
# become NA.
pseudotime_cutoff <- 2
B.gcpc$pt_target_curve_trimmed <- ifelse(
  B.gcpc$pt_target_curve > pseudotime_cutoff,
  B.gcpc$pt_target_curve,
  NA
)

pt_dimplot <- FeaturePlot(B.gcpc, features = "pt_target_curve_trimmed") +
  scale_color_viridis_c("pseudotime") +
  slingshot_path_mappings[[target_curve_index]] +
  ggtitle("")
## Scale for 'colour' is already present. Adding another scale for 'colour',
## which will replace the existing scale.
pt_dimplot

A.3 Pseudotemporal modeling

# One row per cell on the trimmed target curve: cell ID, cluster, pseudotime
# (t1) and one column per feature of the four score assays, sorted by
# pseudotime.
my_pt_1 <- B.gcpc@meta.data[,c("cluster_names", "pt_target_curve_trimmed")] %>%
  set_colnames(c("cluster", "t1")) %>%
  # Set membership needs %in%. The previous `cluster == c("GC LMO2", "GC 1")`
  # recycled the two names along the rows, so each row was compared against
  # only one of them and roughly half of the matching cells were dropped.
  dplyr::filter(cluster %in% c("GC LMO2", "GC 1")) %>%
  rownames_to_column(var = "cellID") %>%
  dplyr::filter(!is.na(t1)) %>%
  dplyr::left_join(
    GetAssayData(B.gcpc, assay = "regulons") %>% t %>% as.data.frame %>% rownames_to_column(var = "cellID"),
    by = "cellID"
  ) %>%
  dplyr::left_join(
    GetAssayData(B.gcpc, assay = "genesetAUC") %>% t %>% as.data.frame %>% rownames_to_column(var = "cellID"),
    by = "cellID"
  ) %>%
  dplyr::left_join(
    GetAssayData(B.gcpc, assay = "btmAUC") %>% t %>% as.data.frame %>% rownames_to_column(var = "cellID"),
    by = "cellID"
  ) %>%
  dplyr::left_join(
    GetAssayData(B.gcpc, assay = "hmAUC") %>% t %>% as.data.frame %>% rownames_to_column(var = "cellID"),
    by = "cellID"
  ) %>%
  dplyr::arrange(t1)

# Long form of the same table: one row per cell x feature.
my_pt_1_long <- pivot_longer(my_pt_1, -c(cellID, cluster, t1), values_to = "AUC", names_to = "geneset")

# Every column that is not an identifier is treated as a feature to model.
my_genesets_1 <- colnames(my_pt_1)[!(colnames(my_pt_1) %in% c("cellID", "cluster", "t1"))]

# Run one GAM per feature: score ~ smooth of pseudotime (k = 3, REML).
# Iterating over the names directly (instead of 1:length(...)) also behaves
# correctly when there are no features at all.
my_gams_1 <- lapply(my_genesets_1, function(geneset_name){
  temp_df <- my_pt_1[, c("t1", geneset_name)]
  # Features that sum to zero are not modelled.
  if (sum(temp_df[, 2]) == 0) return(NULL)
  colnames(temp_df)[2] <- "y"
  mgcv::gam(y ~ s(t1, k = 3), data = temp_df, method = "REML")
})
names(my_gams_1) <- my_genesets_1
# Drop the skipped (NULL) entries so that only fitted models remain.
my_gams_1 <- my_gams_1[lengths(my_gams_1) > 0]

# Format the GAM results: one row per fitted model with the smooth-term
# p-value and summary statistics of the fitted curve on a 100-point
# pseudotime grid.
my_gam_results_1 <- lapply(seq_along(my_gams_1), function(i){
  geneset_name <- names(my_gams_1)[i]
  temp_model <- my_gams_1[[i]]
  if (is.null(temp_model)) return(NULL)
  x_new <- seq(min(my_pt_1$t1), max(my_pt_1$t1), length.out = 100)
  y_pred <- predict(temp_model, data.frame(t1 = x_new))
  largest_diff_real <- diff(range(my_pt_1[, geneset_name]))
  largest_diff_pred <- diff(range(y_pred))
  data.frame(
    geneset = geneset_name,
    p_value = summary(temp_model)$s.table[, "p-value"],
    p_adj = NA_real_, # filled in below, once all p-values are known
    AUC_r_actual = largest_diff_real,
    AUC_r_pred = largest_diff_pred,
    mean_y = mean(y_pred),
    max_y = max(y_pred),
    var_y = stats::var(y_pred),
    sdiff_y = largest_diff_pred / mean(y_pred)
  )
}) %>% do.call(rbind, .)

# Benjamini-Hochberg needs the whole vector of p-values, because the
# correction depends on each p-value's rank. Adjusting one p-value at a time
# with n = N (as was done before) reduces to a Bonferroni correction, p * N.
# n stays at the number of candidate features, so features skipped during
# fitting still count as tests.
if (!is.null(my_gam_results_1)) {
  my_gam_results_1$p_adj <- p.adjust(
    my_gam_results_1$p_value,
    method = "BH",
    n = length(my_genesets_1)
  )
}

# Predicted curve of every fitted GAM on a common 100-point pseudotime grid,
# stacked into one long table (raw prediction plus a z-scored copy).
min_pt <- min(my_pt_1$t1)
max_pt <- max(my_pt_1$t1)
# seq_along() instead of 1:length(): with no fitted models the result is
# simply empty rather than an out-of-bounds error.
imputed_df_1 <- lapply(seq_along(my_gams_1), function(i){
  # The grid is identical for every model, which is what later allows the
  # curves to be spread into shared columns.
  x_new <- seq(min_pt, max_pt, length.out = 100)
  y_pred <- predict(my_gams_1[[i]], data.frame(t1 = x_new))
  if(length(y_pred) < 1){
    return(data.frame())
  }
  y_pred_scaled <- scale(y_pred)
  data.frame(geneset = names(my_gams_1)[i],
             impute_pseudotime = x_new,
             impute_auc = y_pred,
             impute_auc_scale = y_pred_scaled)
}) %>% do.call(rbind, .)

#assoc_gam_1 <- tradeSeq::associationTest(my_gams_1)
#svend_gam_1 <- tradeSeq::startVsEndTest(my_gams_1)
#p_value_choice_genesets <- 0.00001
## best_genesets_1 <- assoc_gam_1 %>%
#   tibble::rownames_to_column(var = "geneset") %>%
#   dplyr::filter(pvalue < p_value_choice_genesets) %>%
#   dplyr::pull(geneset)

# Adjusted p-value threshold for calling a feature pseudotime-associated.
p_value_choice_genesets <- 0.00001

# Names of the features passing the threshold, ordered by the range of their
# fitted curve (ascending).
best_genesets_1 <- my_gam_results_1 %>%
  dplyr::filter(p_adj < p_value_choice_genesets) %>%
  dplyr::arrange(AUC_r_pred) %>%
  dplyr::pull(geneset) %>%
  as.character()

# Feature x pseudotime-grid matrix of the fitted values for those features.
imputed_df_1_wide <- imputed_df_1 %>%
  dplyr::filter(geneset %in% best_genesets_1) %>%
  pivot_wider(id_cols = geneset,
              names_from = impute_pseudotime,
              values_from = impute_auc) %>%
  column_to_rownames(var = "geneset")
# Replace the numeric grid values used as column names by pt_1, pt_2, ...
# seq_len() keeps this valid when no feature passes the threshold (1:0 would
# produce two names for zero columns).
colnames(imputed_df_1_wide) <- paste0("pt_", seq_len(ncol(imputed_df_1_wide)))

save(B.gcpc,
     my_gams_1,
     my_gam_results_1,
     imputed_df_1,
     imputed_df_1_wide,
     best_genesets_1,
     my_pt_1,
     my_pt_1_long,
     file = paste0(parent.directory, "sct_slingshot_modeling_results.RData"))

B. Data visualization

B.1 slingshot visualization

We can now visualize the MST and the principal curves given by slingshot. We also show the pseudotime ordering on the Seurat object.

# Restore the objects written by the save() call at the end of the modeling
# section (B.gcpc, the fitted GAMs, their summary table, the imputed curves,
# the selected genesets and the per-cell pseudotime tables).
load(file = paste0(parent.directory, "sct_slingshot_modeling_results.RData"))

B.2 Heatmap of scaled AUC values

# Settings shared by all heatmaps in this section.
chosen_clusters <- 6
chosen_metric <- "manhattan"

# Split the selected features into regulons and everything else by name.
best_genesets_1_regulons <- best_genesets_1[grepl("regulon", best_genesets_1)]
best_genesets_1_nonregulons <- best_genesets_1[!grepl("regulon", best_genesets_1)]

# All selected features, every row labelled.
cluster_heatmap(
  imputed_df_1_wide[best_genesets_1, ],
  max_zed = 3,
  clusters = chosen_clusters,
  label_features = rownames(imputed_df_1_wide),
  y_text_size = 7,
  color_brewer = "Set3",
  dist_method = chosen_metric
)
## Scale for 'y' is already present. Adding another scale for 'y', which will
## replace the existing scale.
## Scale for 'x' is already present. Adding another scale for 'x', which will
## replace the existing scale.

# Regulons only, every row labelled.
cluster_heatmap(
  imputed_df_1_wide[best_genesets_1_regulons, ],
  max_zed = 3,
  clusters = chosen_clusters,
  label_features = rownames(imputed_df_1_wide),
  y_text_size = 7,
  color_brewer = "Set3",
  dist_method = chosen_metric
)
## Scale for 'y' is already present. Adding another scale for 'y', which will
## replace the existing scale.
## Scale for 'x' is already present. Adding another scale for 'x', which will
## replace the existing scale.

# Non-regulon features only, every row labelled.
cluster_heatmap(
  imputed_df_1_wide[best_genesets_1_nonregulons, ],
  max_zed = 3,
  clusters = chosen_clusters,
  label_features = rownames(imputed_df_1_wide),
  y_text_size = 7,
  color_brewer = "Set3",
  dist_method = chosen_metric
)
## Scale for 'y' is already present. Adding another scale for 'y', which will
## replace the existing scale.
## Scale for 'x' is already present. Adding another scale for 'x', which will
## replace the existing scale.

# Regulons to highlight in the labelled heatmap and the per-feature plots.
chosen_features <- c(
  "IRF8-regulon", "PAX5-regulon", "XBP1-regulon", "CREB3-regulon",
  "CREB3L2-regulon", "ATF6-regulon", "KLF13-regulon", "CUX1-regulon",
  "SPI1-regulon", "YY1-regulon", "CTCF-regulon"
)

# Regulons only, labelling just the highlighted ones.
cluster_heatmap(
  imputed_df_1_wide[best_genesets_1_regulons, ],
  max_zed = 3,
  clusters = chosen_clusters,
  label_features = chosen_features,
  y_text_size = 7,
  color_brewer = "Set3",
  dist_method = chosen_metric
)
## Scale for 'y' is already present. Adding another scale for 'y', which will
## replace the existing scale.
## Scale for 'x' is already present. Adding another scale for 'x', which will
## replace the existing scale.

# One pseudotime scatter + fitted curve per highlighted regulon, arranged in a
# four-column grid; legends and axis titles are removed from each panel.
quick_multi <- lapply(chosen_features, function(feature_name){
  plot_gam_auc(B.gcpc, imputed_df_1, chosen_feature = feature_name) +
    ggtitle(feature_name) +
    NoLegend() +
    xlab(NULL) +
    ylab(NULL)
}) %>%
  cowplot::plot_grid(plotlist = ., ncol = 4)
## Warning: Could not find IRF8-regulon in the default search locations, found in
## regulons assay instead
## Warning: Could not find PAX5-regulon in the default search locations, found in
## regulons assay instead
## Warning: Could not find XBP1-regulon in the default search locations, found in
## regulons assay instead
## Warning: Could not find CREB3-regulon in the default search locations, found in
## regulons assay instead
## Warning: Could not find CREB3L2-regulon in the default search locations, found
## in regulons assay instead
## Warning: Could not find ATF6-regulon in the default search locations, found in
## regulons assay instead
## Warning: Could not find KLF13-regulon in the default search locations, found in
## regulons assay instead
## Warning: Could not find CUX1-regulon in the default search locations, found in
## regulons assay instead
## Warning: Could not find SPI1-regulon in the default search locations, found in
## regulons assay instead
## Warning: Could not find YY1-regulon in the default search locations, found in
## regulons assay instead
## Warning: Could not find CTCF-regulon in the default search locations, found in
## regulons assay instead
quick_multi

B.3 Gene by gene tradeSeq

# Gene-level tradeSeq fit on the cells of the trimmed target curve.
target_cells <- my_pt_1$cellID
count_matrix <- target_subset[["RNA"]]@counts[, target_cells]

# Single-lineage inputs: one pseudotime column and a weight of 1 for every
# cell, both indexed by cell ID.
curve_pseudotimes <- data.frame(
  target_curve = my_pt_1$t1,
  row.names = target_cells
)
cell_weights <- data.frame(
  target_curve = rep(1, length(target_cells)),
  row.names = target_cells
)

# Only genes with more than 5 counts in total across these cells are fitted.
target_genes <- rownames(count_matrix[rowSums(count_matrix) > 5,])
B_sce <- fitGAM(
  counts = count_matrix,
  pseudotime = curve_pseudotimes,
  cellWeights = cell_weights,
  verbose = TRUE,
  nknots = 3,
  genes = target_genes
)

# Association and start-vs-end tests, with gene names moved from the row names
# into a `gene` column.
assoRes <- associationTest(B_sce, lineages = FALSE) %>%
  rownames_to_column(var = "gene")
startRes <- startVsEndTest(B_sce, lineages = FALSE) %>%
  rownames_to_column(var = "gene")

save(
  B_sce, assoRes, startRes, target_genes, curve_pseudotimes, cell_weights,
  file = paste0(parent.directory, "sct_slingshot_gene_results.RData")
)

B.4 Gene visualization by cluster

# Reload the tradeSeq results saved by the gene-level fitting section.
load(file = paste0(parent.directory, "sct_slingshot_gene_results.RData"))

# Genes whose expression is associated with pseudotime.
pt_associated_results <- assoRes %>% dplyr::filter(pvalue < 0.0001)
# Fail with a clear message if nothing passes. Previously 1:length(...)
# iterated over c(1, 0) in that case and failed inside the loop.
if (nrow(pt_associated_results) == 0) {
  stop("No genes pass the pseudotime association threshold (pvalue < 0.0001).",
       call. = FALSE)
}

# Smoothed expression of every associated gene at 40 points along pseudotime.
predicted_smoothers <- lapply(seq_len(nrow(pt_associated_results)), function(i){
  predict_gene <- pt_associated_results[i, "gene"]
  predictSmooth(models = B_sce, gene = predict_gene, nPoints = 40)
}) %>% do.call(rbind, .)

# Gene x pseudotime-point matrix; the time values used as column names are
# replaced by pt_1, pt_2, ...
predicted_smoothers_wide <- predicted_smoothers %>%
  pivot_wider(id_cols = gene, names_from = time, values_from = yhat) %>%
  column_to_rownames(var = "gene")
colnames(predicted_smoothers_wide) <- paste0("pt_", seq_len(ncol(predicted_smoothers_wide)))

num_clusters <- 6
gene_heatmap <- cluster_heatmap(
  predicted_smoothers_wide,
  max_zed = 2,
  clusters = num_clusters,
  label_features = c("IGHG3", "IGHM", "MZB1", "IRF8", "PAX5", "PRDM1", "LTB", "RGS13", "JCHAIN", "LMO2", "SSR4"),
  color_brewer = "Set1",
  y_text_size = 7,
  dist_method = "manhattan"
)
# Cluster number -> colour lookup, built from the same palette and cluster
# count that were passed to cluster_heatmap() above.
cluster_palette <- data.frame(cluster = 1:num_clusters, colors = RColorBrewer::brewer.pal(num_clusters, "Set1"))
gene_heatmap

# plot_regulon_model_1 <- function(x = regulon){
#   x_new <- seq(min(my_pt_1$t1), max(my_pt_1$t1), length.out = length(my_pt_1$t1))
#   y_pred <- predict(my_gams_1[[x]], data.frame(t1 = x_new))
#   model_df <- data.frame(t1 = x_new, AUC = y_pred, derivation = "model")
#   real_df <- my_pt_1_long %>% filter(geneset == x) %>% select(t1, AUC) %>% mutate(derivation = "real")
#   ggplot(real_df, aes(x = t1, y = AUC))+
#     geom_point(color = "black")+
#     geom_line(data = model_df, aes(x = t1, y = AUC), color = "red")
# }